{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "front_matter": {
     "description": "AgentOptimizer is able to prompt LLMs to iteratively optimize the functions/skills of AutoGen agents according to historical conversations and performance.",
     "tags": ["optimization", "tool/function"]
    }
   },
   "source": [
    "# AgentOptimizer: An Agentic Way to Train Your LLM Agent\n",
    "\n",
    "AutoGen offers conversable agents powered by LLM, tool, or human, which can be used to perform tasks collectively via automated chat. This framework allows tool use and human participation through multi-agent conversation.\n",
    "Please find documentation about this feature [here](https://microsoft.github.io/autogen/docs/Use-Cases/agent_chat).\n",
    "\n",
    "In a traditional ML pipeline, we train a model by updating its parameters according to the loss on the training set. In the era of LLM agents, how should we train an agent? Here, we take an initial step towards agent training. Inspired by the [function calling](https://platform.openai.com/docs/guides/function-calling) capabilities provided by OpenAI, we draw an analogy between model parameters and agent functions/skills, and update the agent’s functions/skills based on its historical performance on the training set. As an agentic way of training an agent, our approach helps enhance the agent’s abilities without requiring access to the LLM’s parameters.\n",
    "\n",
    "In this notebook, we introduce a new class, `AgentOptimizer`, which improves the function list of an Assistant-UserProxy agent pair based on their historical conversations.\n",
    "This helps the agents improve their ability to solve problems of the same type as previous tasks.\n",
    "Specifically, given a set of training data, AgentOptimizer iteratively prompts the LLM to optimize the existing function lists of the AssistantAgent and UserProxyAgent, adding code implementations where necessary. It also includes two strategies, roll-back and early-stop, to streamline the training process.\n",
    "In the example scenario, we test the proposed AgentOptimizer in solving problems from the [MATH dataset](https://github.com/hendrycks/math). \n",
    "\n",
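    "The roll-back and early-stop strategies can be illustrated with a minimal, library-independent sketch. It assumes a hypothetical `propose` callable that stands in for the LLM's update of the function list and a hypothetical `evaluate` callable that scores a function list on the training set. In this sketch, a candidate is kept only if it strictly improves the training score. It illustrates the idea and is not AgentOptimizer's actual implementation.\n",
    "\n",
    "```python\n",
    "from typing import Callable, List\n",
    "\n",
    "\n",
    "def train_with_rollback(\n",
    "    functions: List[str],\n",
    "    propose: Callable[[List[str]], List[str]],\n",
    "    evaluate: Callable[[List[str]], float],\n",
    "    max_epochs: int = 10,\n",
    "    patience: int = 3,\n",
    ") -> List[str]:\n",
    "    # Keep the best-scoring function list seen so far.\n",
    "    best_functions, best_score = list(functions), evaluate(functions)\n",
    "    epochs_without_gain = 0\n",
    "    for _ in range(max_epochs):\n",
    "        candidate = propose(best_functions)  # stand-in for the LLM's proposed update\n",
    "        score = evaluate(candidate)\n",
    "        if score > best_score:\n",
    "            best_functions, best_score = candidate, score\n",
    "            epochs_without_gain = 0\n",
    "        else:\n",
    "            # Roll-back: discard the candidate and continue from the previous best list.\n",
    "            epochs_without_gain += 1\n",
    "            # Early-stop: give up after `patience` consecutive epochs without improvement.\n",
    "            if epochs_without_gain >= patience:\n",
    "                break\n",
    "    return best_functions\n",
    "\n",
    "\n",
    "# Toy usage: each proposal appends one function name, and the score saturates at 2.\n",
    "best = train_with_rollback(\n",
    "    [], propose=lambda fs: fs + ['f%d' % len(fs)], evaluate=lambda fs: min(len(fs), 2)\n",
    ")\n",
    "print(best)  # ['f0', 'f1']\n",
    "```\n",
    "\n",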
    "![AgentOptimizer](../website/blog/2023-12-23-AgentOptimizer/img/agentoptimizer.png)\n",
    "\n",
    "More information can be found in the [paper](https://arxiv.org/abs/2402.11359).\n",
    "\n",
    "Authors:\n",
    "- [Shaokun Zhang](https://github.com/skzhang1), Ph.D. student at The Pennsylvania State University\n",
    "- [Jieyu Zhang](https://jieyuz2.github.io), Ph.D. student at the University of Washington"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Any, Callable, Dict, List, Optional, Tuple, Union\n",
    "from autogen.agentchat.contrib.agent_optimizer import AgentOptimizer\n",
    "from autogen.agentchat.contrib.math_user_proxy_agent import MathUserProxyAgent\n",
    "from autogen.agentchat import Agent\n",
    "from openai import BadRequestError\n",
    "from autogen.code_utils import extract_code\n",
    "from autogen.math_utils import get_answer\n",
    "from autogen import config_list_from_json\n",
    "import autogen\n",
    "import json\n",
    "import copy"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# MathUserProxy with function_call\n",
    "\n",
    "This agent is a customized MathUserProxyAgent that inherits from its [parent class](https://github.com/microsoft/autogen/blob/main/autogen/agentchat/contrib/math_user_proxy_agent.py).\n",
    "\n",
    "It supports solving math problems with both function calls and Python code.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "def is_termination_msg_mathchat(message):\n",
    "    \"\"\"Check if a message is a termination message.\"\"\"\n",
    "    if isinstance(message, dict):\n",
    "        message = message.get(\"content\")\n",
    "        if message is None:\n",
    "            return False\n",
    "    cb = extract_code(message)\n",
    "    contain_code = False\n",
    "    for c in cb:\n",
    "        if c[0] == \"python\":\n",
    "            contain_code = True\n",
    "            break\n",
    "    if message.rstrip().find(\"TERMINATE\") >= 0:\n",
    "        return True\n",
    "    return not contain_code and get_answer(message) is not None and get_answer(message) != \"\"\n",
    "\n",
    "\n",
    "class MathUserProxyAgent(MathUserProxyAgent):\n",
    "    MAX_CONSECUTIVE_AUTO_REPLY = 15\n",
    "    DEFAULT_REPLY = \"Continue. Please keep solving the problem until you need to query. (If you get to the answer, put it in \\\\boxed{}.)\"\n",
    "    PROMPTS = \"\"\"Let's solve a math problem.\n",
    "Query requirements:\n",
    "You should always use the 'print' function for the output and use fractions/radical forms instead of decimals.\n",
    "You can use packages like sympy to help you.\n",
    "You must follow the formats below to write your code:\n",
    "```python\n",
    "# your code\n",
    "```\n",
    "If some packages are missing, you could also suggest a code to install the corresponding package.\n",
    "\n",
    "Please follow this process:\n",
    "1. Solve the problem step by step (do not over-divide the steps).\n",
    "2. Take out any queries that can be asked through Python code (for example, any calculations or equations that can be calculated) and functions you know in the context of this conversation.\n",
    "\n",
    "Please\n",
    "(1) do not mix suggested Python codes and function calls in one step.\n",
    "(2) You MUST remember that you don’t have a function named \"python\" available.\n",
    "\n",
    "You must follow the formats below to write your Python code:\n",
    "```python\n",
    "# your code\n",
    "```\n",
    "\n",
    "3. Wait for me to give the results or wait for the executed results of the function call.\n",
    "4. Continue if you think the result is correct. If the result is invalid or unexpected, please correct your query or reasoning.\n",
    "\n",
    "After all the queries are run and you get the answer, put the answer in \\\\boxed{}.\n",
    "\n",
    "Problem:\n",
    "\"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        name: Optional[str] = \"MathChatAgent\",\n",
    "        is_termination_msg: Optional[Callable[[Dict], bool]] = is_termination_msg_mathchat,\n",
    "        human_input_mode: Optional[str] = \"NEVER\",\n",
    "        default_auto_reply: Optional[Union[str, Dict, None]] = DEFAULT_REPLY,\n",
    "        max_invalid_q_per_step=3,\n",
    "        **kwargs,\n",
    "    ):\n",
    "        super().__init__(\n",
    "            name=name,\n",
    "            is_termination_msg=is_termination_msg,\n",
    "            human_input_mode=human_input_mode,\n",
    "            default_auto_reply=default_auto_reply,\n",
    "            max_invalid_q_per_step=max_invalid_q_per_step,\n",
    "            **kwargs,\n",
    "        )\n",
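    "        # Replace some of the parent's default reply functions with this class's versions.\n",
    "        # The hard-coded indices depend on the order in which autogen registers reply functions.\n",
    "        del self._reply_func_list[2]\n",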
    "        self.register_reply([Agent, None], MathUserProxyAgent._generate_math_reply, position=4)\n",
    "        del self._reply_func_list[3]\n",
    "        self.register_reply(\n",
    "            trigger=autogen.ConversableAgent, reply_func=MathUserProxyAgent.generate_function_call_reply, position=3\n",
    "        )\n",
    "        self.register_reply(\n",
    "            trigger=autogen.ConversableAgent, reply_func=MathUserProxyAgent._check_final_result, position=0\n",
    "        )\n",
    "\n",
    "        self.max_function_call_trial = 3\n",
    "        self.query = None\n",
    "        self._answer = None\n",
    "        self.is_correct = None\n",
    "\n",
    "    def generate_function_call_reply(\n",
    "        self,\n",
    "        messages: Optional[List[Dict]] = None,\n",
    "        sender: Optional[autogen.ConversableAgent] = None,\n",
    "        config: Optional[Any] = None,\n",
    "    ) -> Tuple[bool, Union[Dict, None]]:\n",
    "        \"\"\"Generate a reply using function call.\"\"\"\n",
    "        if messages is None:\n",
    "            messages = self._oai_messages[sender]\n",
    "        message = messages[-1]\n",
    "        if \"function_call\" in message:\n",
    "            is_exec_success, func_return = self.execute_function(message[\"function_call\"])\n",
    "            if is_exec_success:\n",
    "                self.max_function_call_trial = 3\n",
    "                return True, func_return\n",
    "            else:\n",
    "                error_message = func_return[\"content\"]\n",
    "                if self.max_function_call_trial == 0:\n",
    "                    # Too many failed attempts: reset the budget and ask the assistant to end the chat.\n",
    "                    self.max_function_call_trial = 3\n",
    "                    return (\n",
    "                        True,\n",
    "                        \"The function call failed repeatedly. \"\n",
    "                        + error_message\n",
    "                        + \". Please reply directly with TERMINATE. We need to terminate the conversation.\",\n",
    "                    )\n",
    "                # Use up one retry and ask the assistant to either fix its call or terminate.\n",
    "                self.max_function_call_trial -= 1\n",
    "                revise_prompt = (\n",
    "                    \" You may have made a wrong function call (for example, the arguments you provided may not \"\n",
    "                    \"match the function signature, such as a missing required positional argument). \"\n",
    "                    \"If you think the error is caused by wrong arguments and you can fix them, please call this \"\n",
    "                    \"function again with the correct arguments. Otherwise, the error may be caused by the function \"\n",
    "                    \"itself. Please reply directly with TERMINATE. We need to terminate the conversation.\"\n",
    "                )\n",
    "                return True, \"The function call failed. \" + error_message + revise_prompt\n",
    "        return False, None\n",
    "\n",
    "    def initiate_chat(\n",
    "        self,\n",
    "        recipient,\n",
    "        answer: Union[str, int, float],\n",
    "        silent: Optional[bool] = False,\n",
    "        **context,\n",
    "    ):\n",
    "        self.query = context[\"problem\"]\n",
    "        # Normalize numeric answers to strings (e.g. 5.0 -> \"5\") so they can be compared with the extracted answer.\n",
    "        if not isinstance(answer, str):\n",
    "            answer = str(answer)\n",
    "            if answer.endswith(\".0\"):\n",
    "                answer = answer[:-2]\n",
    "        self._answer = answer\n",
    "\n",
    "        self.is_correct = None\n",
    "\n",
    "        self._prepare_chat(recipient, True)\n",
    "        error_message = None\n",
    "        try:\n",
    "            prompt = self.PROMPTS + context[\"problem\"]\n",
    "            self.send(prompt, recipient, silent=silent)\n",
    "        except BadRequestError as e:\n",
    "            error_message = str(e)\n",
    "            self.is_correct = 0\n",
    "            print(\"error information: {}\".format(error_message))\n",
    "\n",
    "        recipient.reset()\n",
    "        is_correct = copy.deepcopy(self.is_correct)\n",
    "        self._reset()\n",
    "        return is_correct\n",
    "\n",
    "    def _check_final_result(\n",
    "        self,\n",
    "        messages: Optional[List[Dict]] = None,\n",
    "        sender: Optional[autogen.Agent] = None,\n",
    "        config: Optional[Any] = None,\n",
    "    ):\n",
    "\n",
    "        messages = messages[-1]\n",
    "        if isinstance(messages, dict):\n",
    "            messages = messages.get(\"content\")\n",
    "            if messages is None:\n",
    "                return False, None\n",
    "\n",
    "        cb = extract_code(messages)\n",
    "        contain_code = False\n",
    "        for c in cb:\n",
    "            if c[0] == \"python\":\n",
    "                contain_code = True\n",
    "                break\n",
    "        if not contain_code and get_answer(messages) is not None and get_answer(messages) != \"\":\n",
    "            if get_answer(messages) == self._answer:\n",
    "                self.is_correct = 1\n",
    "                return True, \"The result is Correct. Please reply me with TERMINATE.\"\n",
    "            else:\n",
    "                self.is_correct = 0\n",
    "                return False, None\n",
    "        else:\n",
    "            return False, None\n",
    "\n",
    "    def _reset(self):\n",
    "        super()._reset()\n",
    "        self.max_function_call_trial = 3\n",
    "        self.is_correct = None\n",
    "        self.query = None\n",
    "        self._answer = None"
   ]
  },
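  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `_check_final_result` reply function above uses `get_answer` to extract the assistant's final answer and compares it with the ground truth. As a rough illustration of that extraction step, the sketch below returns the contents of the last `\\boxed{...}` expression and keeps nested braces intact. `last_boxed_answer` is a hypothetical helper written for illustration. It is not autogen's `get_answer` implementation, and it performs no answer normalization.\n",
    "\n",
    "```python\n",
    "from typing import Optional\n",
    "\n",
    "\n",
    "def last_boxed_answer(text: str) -> Optional[str]:\n",
    "    # Locate the last 'boxed{' marker (the LaTeX source is \\boxed{...}).\n",
    "    start = text.rfind('boxed{')\n",
    "    if start == -1:\n",
    "        return None\n",
    "    begin = start + len('boxed{')\n",
    "    depth = 1\n",
    "    # Scan forward and track brace depth so nested groups such as \\frac{1}{2} survive.\n",
    "    for pos in range(begin, len(text)):\n",
    "        if text[pos] == '{':\n",
    "            depth += 1\n",
    "        elif text[pos] == '}':\n",
    "            depth -= 1\n",
    "            if depth == 0:\n",
    "                return text[begin:pos].strip()\n",
    "    return None  # unbalanced braces: no complete answer\n",
    "\n",
    "\n",
    "print(last_boxed_answer(r'So the answer is \\boxed{\\frac{1}{2}}.'))  # \\frac{1}{2}\n",
    "```"
   ]
  },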
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load dataset\n",
    "\n",
    "The MATH dataset contains 12,500 challenging competition mathematics problems. Each problem in MATH has a full step-by-step solution, which can be used to teach models to generate answer derivations and explanations.\n",
    "\n",
    "We strictly follow the [train](https://github.com/lifan-yuan/CRAFT/blob/main/tab_and_math/MATH/dataset/train/algebra.jsonl)/[test](https://github.com/lifan-yuan/CRAFT/blob/main/tab_and_math/MATH/dataset/algebra.jsonl) splits of [CRAFT](https://github.com/lifan-yuan/CRAFT). Please specify your own path to the dataset. Here we take the first 10 algebra problems from each split as examples."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_data, train_data = [], []\n",
    "with open(\"MATH/dataset/algebra.jsonl\", \"r\", encoding=\"utf-8\") as f:\n",
    "    for line in f:\n",
    "        test_data.append(json.loads(line))\n",
    "with open(\"MATH/dataset/train/algebra.jsonl\", \"r\", encoding=\"utf-8\") as f:\n",
    "    for line in f:\n",
    "        train_data.append(json.loads(line))\n",
    "test_data, train_data = test_data[0:10], train_data[0:10]"
   ]
  },
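  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The loading cell above reads each file in full and then keeps the first 10 records. If you only need a small sample, the following sketch shows one variant that stops reading early. It also checks that every record has the `question` and `answer` fields used later in this notebook. `load_jsonl` is a hypothetical helper, not part of autogen.\n",
    "\n",
    "```python\n",
    "import json\n",
    "from typing import Dict, Iterable, List\n",
    "\n",
    "\n",
    "def load_jsonl(lines: Iterable[str], limit: int) -> List[Dict]:\n",
    "    records: List[Dict] = []\n",
    "    for line in lines:\n",
    "        if len(records) >= limit:\n",
    "            break  # stop reading once enough records are collected\n",
    "        if not line.strip():\n",
    "            continue  # skip blank lines, e.g. a trailing newline at the end of the file\n",
    "        record = json.loads(line)\n",
    "        missing = [key for key in ('question', 'answer') if key not in record]\n",
    "        if missing:\n",
    "            raise KeyError(f'record {len(records)} is missing fields: {missing}')\n",
    "        records.append(record)\n",
    "    return records\n",
    "\n",
    "\n",
    "# Usage with a file:\n",
    "#     with open('MATH/dataset/algebra.jsonl', encoding='utf-8') as f:\n",
    "#         test_data = load_jsonl(f, limit=10)\n",
    "rows = [{'question': 'What is 1 + 1?', 'answer': 2}, {'question': 'What is 2 + 2?', 'answer': 4}]\n",
    "print(load_jsonl([json.dumps(row) for row in rows], limit=1))\n",
    "# [{'question': 'What is 1 + 1?', 'answer': 2}]\n",
    "```"
   ]
  },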
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Agents construction\n",
    "\n",
    "We construct the MathUserProxyAgent and AssistantAgent used to solve these problems. Here, we use gpt-4-1106-preview (configured in `OAI_CONFIG_LIST`) to construct the AssistantAgent."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "config_list = config_list_from_json(env_or_file=\"OAI_CONFIG_LIST\")\n",
    "\n",
    "assistant = autogen.AssistantAgent(\n",
    "    name=\"assistant\",\n",
    "    system_message=\"You are a helpful assistant.\",\n",
    "    llm_config={\n",
    "        \"timeout\": 600,\n",
    "        \"seed\": 42,\n",
    "        \"config_list\": config_list,\n",
    "    },\n",
    ")\n",
    "user_proxy = MathUserProxyAgent(\n",
    "    name=\"mathproxyagent\",\n",
    "    human_input_mode=\"NEVER\",\n",
    "    code_execution_config={\"work_dir\": \"_output\", \"use_docker\": False},\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Test without agent optimizations \n",
    "\n",
    "Below is the code to measure performance without the agent optimization process.\n",
    "\n",
    "In this case, the AssistantAgent and MathUserProxyAgent have no functions available and solve problems solely with Python code."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_correct = 0\n",
    "for query in test_data:\n",
    "    is_correct = user_proxy.initiate_chat(recipient=assistant, answer=query[\"answer\"], problem=query[\"question\"])\n",
    "    print(is_correct)\n",
    "    # is_correct stays None if no final answer was detected; count that as incorrect.\n",
    "    num_correct += 1 if is_correct else 0\n",
    "success_rate_without_agent_training = num_correct / len(test_data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Agent Training \n",
    "\n",
    "Then, we use the AgentOptimizer to iteratively optimize the agents by updating their functions according to the historical conversations and performance.\n",
    "At each iteration, `optimizer.step()` returns `register_for_llm` (function signatures to update on the assistant) and `register_for_executor` (the corresponding function implementations), which are used to update the `assistant` and `user_proxy` agents, respectively.\n",
    "Here we optimize these two agents for ten epochs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "EPOCH = 10\n",
    "optimizer_model = \"gpt-4-1106-preview\"\n",
    "optimizer = AgentOptimizer(\n",
    "    max_actions_per_step=3, config_file_or_env=\"OAI_CONFIG_LIST\", optimizer_model=optimizer_model\n",
    ")\n",
    "for i in range(EPOCH):\n",
    "    for index, query in enumerate(train_data):\n",
    "        is_correct = user_proxy.initiate_chat(assistant, answer=query[\"answer\"], problem=query[\"question\"])\n",
    "        history = assistant.chat_messages_for_summary(user_proxy)\n",
    "        optimizer.record_one_conversation(history, is_satisfied=is_correct)\n",
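    "    # Ask the optimizer to update the function list based on the recorded conversations.\n",
    "    register_for_llm, register_for_executor = optimizer.step()\n",
    "    # Update the function signatures that the assistant (LLM) can call...\n",
    "    for item in register_for_llm:\n",
    "        assistant.update_function_signature(**item)\n",
    "    # ...and register the matching implementations with the user proxy, which executes them.\n",
    "    if register_for_executor:\n",
    "        user_proxy.register_function(function_map=register_for_executor)"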
   ]
  },
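  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Each learned function has two parts:\n",
    "\n",
    "- a signature that the assistant sees, applied with `update_function_signature`;\n",
    "- a Python implementation that the user proxy executes, registered with `register_function`.\n",
    "\n",
    "As a hypothetical illustration, the sketch below shows what the learned `evaluate_expression` function listed at the end of this notebook could look like. It pairs an OpenAI function-calling style schema with a restricted arithmetic evaluator built on Python's `ast` module. It is not the code the optimizer actually generated. The `func_sig`/`is_remove` keys in `register_for_llm` are also an assumption about the shape of the optimizer's output.\n",
    "\n",
    "```python\n",
    "import ast\n",
    "import operator\n",
    "\n",
    "# Hypothetical schema in OpenAI function-calling format (what the assistant sees).\n",
    "EVALUATE_EXPRESSION_SIG = {\n",
    "    'name': 'evaluate_expression',\n",
    "    'description': 'Evaluate arithmetic or mathematical expressions provided as strings.',\n",
    "    'parameters': {\n",
    "        'type': 'object',\n",
    "        'properties': {'expression': {'type': 'string', 'description': 'e.g. 2*(3+4)'}},\n",
    "        'required': ['expression'],\n",
    "    },\n",
    "}\n",
    "\n",
    "# Only these operators are allowed; anything else (names, calls, attributes) is rejected.\n",
    "_OPS = {\n",
    "    ast.Add: operator.add,\n",
    "    ast.Sub: operator.sub,\n",
    "    ast.Mult: operator.mul,\n",
    "    ast.Div: operator.truediv,\n",
    "    ast.Pow: operator.pow,\n",
    "    ast.USub: operator.neg,\n",
    "    ast.UAdd: operator.pos,\n",
    "}\n",
    "\n",
    "\n",
    "def evaluate_expression(expression: str) -> str:\n",
    "    # Recursively evaluate the parsed expression tree, allowing only numbers and _OPS.\n",
    "    def _eval(node):\n",
    "        if isinstance(node, ast.Expression):\n",
    "            return _eval(node.body)\n",
    "        if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)):\n",
    "            return node.value\n",
    "        if isinstance(node, ast.BinOp) and type(node.op) in _OPS:\n",
    "            return _OPS[type(node.op)](_eval(node.left), _eval(node.right))\n",
    "        if isinstance(node, ast.UnaryOp) and type(node.op) in _OPS:\n",
    "            return _OPS[type(node.op)](_eval(node.operand))\n",
    "        raise ValueError(f'unsupported expression: {expression!r}')\n",
    "\n",
    "    return str(_eval(ast.parse(expression, mode='eval')))\n",
    "\n",
    "\n",
    "# Assumed shapes of what the training loop above consumes:\n",
    "register_for_llm = [{'func_sig': EVALUATE_EXPRESSION_SIG, 'is_remove': False}]\n",
    "register_for_executor = {'evaluate_expression': evaluate_expression}\n",
    "print(register_for_executor['evaluate_expression']('2 * (3 + 4) - 2 ** 3'))  # 6\n",
    "```"
   ]
  },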
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Test with agent optimizations \n",
    "\n",
    "After 10 optimization iterations, the agents have obtained from the AgentOptimizer the list of functions shown below.\n",
    "\n",
    "We then compare the final performance with and without the agent optimization process. On these 10 test problems, the success rate improves from 60% without optimization to 90% with optimization.\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_correct = 0\n",
    "for query in test_data:\n",
    "    is_correct = user_proxy.initiate_chat(recipient=assistant, answer=query[\"answer\"], problem=query[\"question\"])\n",
    "    # is_correct stays None if no final answer was detected; count that as incorrect.\n",
    "    num_correct += 1 if is_correct else 0\n",
    "success_rate_with_agent_training = num_correct / len(test_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "------------------------------------------------Functions learned------------------------------------------------\n",
      "evaluate_expression: Evaluate arithmetic or mathematical expressions provided as strings.\n",
      "\n",
      "calculate_compound_interest_principal: Calculate the principal amount needed to achieve a certain future value with quarterly compound interest.\n",
      "\n",
      "solve_linear_system: Solve a system of linear equations represented as coefficients and variables.\n",
      "\n",
      "------------------------------------------------Summary------------------------------------------------\n",
      "\n",
      "success_rate_without_agent_training: 60.0%\n",
      "\n",
      "success_rate_with_agent_training: 90.0%\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(\n",
    "    \"------------------------------------------------Functions learned------------------------------------------------\"\n",
    ")\n",
    "for func in assistant.llm_config.get(\"functions\", []):\n",
    "    print(func[\"name\"] + \": \" + func[\"description\"] + \"\\n\")\n",
    "print(\"------------------------------------------------Summary------------------------------------------------\\n\")\n",
    "print(\"success_rate_without_agent_training: {average}%\\n\".format(average=success_rate_without_agent_training * 100))\n",
    "print(\"success_rate_with_agent_training: {average}%\\n\".format(average=success_rate_with_agent_training * 100))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py3.9",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
